{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial - Hugging Face Hub model sharing 🤗\n",
    "\n",
    "In this notebook, we will see how to share your models with the community using Pythae's integration with the Hugging Face Hub: we train a model, push it to the Hub, reload it from the Hub and sample from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the library\n",
    "%pip install pythae"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train your Pythae model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.datasets as datasets\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the MNIST training split (downloaded to ../data if missing), without any torchvision transform.\n",
    "mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)\n",
    "\n",
    "# Add a channel axis -> (N, 1, 28, 28) and divide by 255 to rescale the pixel values.\n",
    "mnist_images = mnist_trainset.data.reshape(-1, 1, 28, 28) / 255.\n",
    "\n",
    "# The last 10,000 images are held out for evaluation, the rest is used for training.\n",
    "train_dataset = mnist_images[:-10000]\n",
    "eval_dataset = mnist_images[-10000:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from pythae.models import BetaVAE, BetaVAEConfig\n",
    "from pythae.trainers import BaseTrainerConfig\n",
    "from pythae.pipelines.training import TrainingPipeline\n",
    "from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST, Decoder_ResNet_AE_MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings; checkpoints and the final model are written under `output_dir`.\n",
    "config = BaseTrainerConfig(\n",
    "    output_dir='my_model',\n",
    "    learning_rate=1e-4,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=64,\n",
    "    num_epochs=1,  # Change this to train the model a bit more\n",
    ")\n",
    "\n",
    "# Model settings, shared by the model, its encoder and its decoder.\n",
    "model_config = BetaVAEConfig(input_dim=(1, 28, 28), latent_dim=16, beta=2.)\n",
    "\n",
    "# Beta-VAE built with the MNIST ResNet benchmark encoder/decoder.\n",
    "model = BetaVAE(\n",
    "    model_config=model_config,\n",
    "    encoder=Encoder_ResNet_VAE_MNIST(model_config),\n",
    "    decoder=Decoder_ResNet_AE_MNIST(model_config),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the model and the training settings into a training pipeline.\n",
    "pipeline = TrainingPipeline(training_config=config, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pipeline on the train / eval splits built above.\n",
    "pipeline(train_data=train_dataset, eval_data=eval_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reload your trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pythae.models import AutoModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the last entry of `my_model` in sorted name order\n",
    "# (presumably the most recent training run - confirm the folder naming scheme).\n",
    "training_runs = sorted(os.listdir('my_model'))\n",
    "last_training = training_runs[-1]\n",
    "\n",
    "# Reload the model saved in that run's `final_model` folder.\n",
    "final_model_dir = os.path.join('my_model', last_training, 'final_model')\n",
    "trained_model = AutoModel.load_from_folder(final_model_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Now let's share your model with the community through the Hugging Face Hub! 🤗"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To be able to access this feature you will need:\n",
    "- a valid *username* from your Hugging Face account.\n",
    "- the `huggingface_hub` package installed in your virtual env. You can install it by running `$ python -m pip install huggingface_hub`.\n",
    "- to be logged in to your Hugging Face account, which you can do by running `$ huggingface-cli login`.\n",
    "\n",
    "**Note**: If the repo you specified is not empty, its content will be overridden. If the repo does not exist, it will be created automatically under the name that was specified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Before pushing or loading a model from the Hub you may need to run the following\n",
    "# !python -m pip install huggingface_hub\n",
    "# !huggingface-cli login"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model to the Hub by specifying your username and the name of the repo in which you want to save your model\n",
    "trained_model.push_to_hf_hub(\"your_hf_username/my_beta_vae\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_model_from_hf = AutoModel.load_from_hf_hub(\"your_hf_username/my_beta_vae\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the weights reloaded from the Hub are identical to the local ones.\n",
    "# The keys come from `trained_model` itself, so the check only involves the two\n",
    "# models being compared (not the `model` object created before training).\n",
    "local_state = trained_model.state_dict()\n",
    "hub_state = trained_model_from_hf.state_dict()\n",
    "\n",
    "assert local_state.keys() == hub_state.keys(), 'state dict keys differ'\n",
    "assert all(torch.equal(local_state[key], hub_state[key]) for key in local_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use your model to do whatever you want"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.samplers import NormalSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a normal sampler on top of the model reloaded from the Hub\n",
    "normal_sampler = NormalSampler(model=trained_model_from_hf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw 25 samples (enough to fill a 5x5 grid)\n",
    "gen_data = normal_sampler.sample(num_samples=25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show results with normal sampler\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "# `axes.flat` walks the 5x5 grid row by row, so the k-th axis shows the k-th sample.\n",
    "for ax, sample in zip(axes.flat, gen_data):\n",
    "    ax.imshow(sample.cpu().squeeze(0), cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3efa06c4da850a09a4898b773c7e91b0da3286dbbffa369a8099a14a8fa43098"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
